/*
Copyright dyllen_zhong@qq.com
link: http://www.360us.net/

模拟ping程序
*/

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"net"
	"os"
	"strconv"
	"time"
)

//帮助提示
const HELP = "错误，用法如下：\nping [-c number] ip/domain\n\n-c number 指定执行次数，不指定默认4次\nip/domain 指定要ping的ip或者域名"

//icmp包结构
type icmp struct {
	Type       uint8
	Code       uint8
	Checksum   uint16
	Identifier uint16
	Sequence   uint16
}

func main() {
	checkInput()

	var dest string
	var numbers int = 0
	if len(os.Args) == 2 {
		dest = os.Args[1]
	} else if len(os.Args) == 4 {
		dest = os.Args[3]
		n, err := strconv.ParseInt(os.Args[2], 10, 0)
		if err != nil {
			fmt.Println(err)
			os.Exit(1)
		}
		numbers = int(n)
	}

	if numbers == 0 || numbers < 1 {
		numbers = 4
	}

	var (
		icmpPack icmp
		laddr    net.IPAddr = net.IPAddr{IP: net.ParseIP("0.0.0.0")} //源地址
		raddr, _            = net.ResolveIPAddr("ip", dest)          //目的地址
		response []byte     = make([]byte, 128)                      //保存响应数据
	)

	conn, err := net.DialIP("ip:icmp", &laddr, raddr) //建立ip连接
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	defer conn.Close()

	icmpPack.Type = 8
	icmpPack.Code = 0
	icmpPack.Checksum = 0 //计算Checksum之前置为0
	icmpPack.Identifier = 0
	icmpPack.Sequence = 0
	Data := []byte("123456abcdefghijklmnopqrstuvwxyz") //自定义数据

	var buffer bytes.Buffer //数据缓冲

	//计算校验和
	binary.Write(&buffer, binary.BigEndian, icmpPack) //写入ICMP头
	binary.Write(&buffer, binary.BigEndian, Data)     //写入自定义数据

	icmpPack.Checksum = checkSum(buffer.Bytes())

	buffer.Reset() //清空buffer

	//生成最终发送数据
	binary.Write(&buffer, binary.BigEndian, icmpPack) //写入ICMP头
	binary.Write(&buffer, binary.BigEndian, Data)     //写入自定义数据

	fmt.Printf("\n正 在Ping %s [%s] 具有%d字节的数据：\n", dest, raddr.String(), len(Data))

	var sent, receive int //已发送/已接收的数量统计
	var times []int       //记录每次完整请求花费的时间

	for c := 0; c < numbers; c++ {
		//发包
		_, err = conn.Write(buffer.Bytes())
		if err != nil {
			fmt.Println(err)
			os.Exit(1)
		}
		sent++

		t_start := time.Now() //请求开始时间

		conn.SetReadDeadline(time.Now().Add(time.Second * 5)) //设置读取超时时间点

		//读取响应，响应的slice里面，前20个字节是ip头，剩下的都是icmp包
		_, err = conn.Read(response)
		if err != nil {
			fmt.Println("请求超时")
			continue
		}

		/*
			计算响应的数据包的校验和，校验和不对丢弃
			1、把首部看成以16位为单位的数字组成，依次进行二进制反码求和，包括校验和字段；
			2、检查计算出的校验和的结果是否为0；
			3、如果等于0，说明被整除，校验和正确。否则，校验和就是错误的，协议栈要抛弃这个数据包。
		*/
		if checkSum(response[20:]) != 0 {
			continue
		}

		t_end := time.Now()                           //请求结束时间
		dur := t_end.Sub(t_start).Nanoseconds() / 1e6 //消耗时间

		receive++
		times = append(times, int(dur))

		fmt.Printf("来自 %s 的回复: 字节=%d 时间 = %dms TTL=%d\n", raddr.String(), len(Data), dur, response[8])

		//每次请求完成暂停一秒
		time.Sleep(time.Second)
	}

	result(raddr.String(), sent, receive, times) //展示最终结果统计信息
	os.Exit(0)
}

/*
求校验和步骤：
（1）把校验和字段置为0；
（2）把需校验的数据看成以16位为单位的数字组成，依次进行二进制反码求和；
（3）把得到的结果存入校验和字段中。
*/
func checkSum(data []byte) uint16 {
	var (
		sum    uint32
		length int = len(data)
		index  int
	)
	for length > 1 {
		sum += uint32(data[index])<<8 + uint32(data[index+1])
		index += 2
		length -= 2
	}
	if length > 0 {
		sum += uint32(data[index])
	}
	sum += (sum >> 16)

	return uint16(^sum)
}

//检查用户输入
func checkInput() {
	if len(os.Args) > 4 {
		fmt.Println(HELP)
		os.Exit(1)
	}
	switch len(os.Args) {
	case 1:
		fmt.Println(HELP)
		os.Exit(1)
	case 3:
		fmt.Println(HELP)
		os.Exit(1)
	case 4:
		_, err := strconv.ParseInt(os.Args[2], 10, 0)
		if (os.Args[1] != "-c") || (err != nil) {
			fmt.Println(HELP)
			os.Exit(1)
		}
	}
}

//结果统计
func result(addr string, sent, receive int, times []int) {
	var (
		format string  = "\n%s 的 Ping 统计信息：\n    数据包：已发送 = %d，已接收 = %d，丢失 = %d （%.1f%% 丢失），\n往返行程的估计时间（以毫秒为单位）：\n    最短 = %dms，最长 = %dms，平均 = %.0fms\n"
		min    int     //最短时间
		max    int     //最长时间
		count  int     //花费的总时间
		losted int     //丢包数
		t      int     //成功发送的包总数
		avg    float32 //平均时间
	)
	if len(times) > 0 {
		min, max, count, losted, t = times[0], times[0], 0, sent-receive, len(times)
	} else {
		min, max, count, losted, t = 0, 0, 0, sent-receive, len(times)
	}
	for k, v := range times {
		count += v
		if k == 0 {
			continue
		}
		if min > v {
			min = v
		}
		if max < v {
			max = v
		}
	}
	if t == 0 {
		avg = float32(0)
	} else {
		avg = float32(count) / float32(t)
	}
	fmt.Printf(format, addr, sent, receive, losted, float32(losted)/float32(sent)*100, min, max, avg)
}
